import torch
import torch.nn as nn
import os
import numpy as np
import random
import json
import jsonlines

import time
import argparse
import sklearn
import traceback

import sys
sys.path.append("..") # add the parent directory to the module search path
from torch.utils import data
from tqdm import tqdm
from scipy.special import softmax

from ditto_light.ditto import evaluate, DittoModel
from ditto_light.exceptions import ModelNotFoundError
from ditto_light.dataset import DittoDataset
from ditto_light.summarize import Summarizer
from ditto_light.knowledge import *
from sfda.shot_train import shot_train

def set_seed(seed: int):
    """Seed ``random``, ``numpy`` and ``torch`` (CPU and every CUDA device)
    so that runs are reproducible.
    """
    for seeder in (random.seed,
                   np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)


def to_str(ent1, ent2, summarizer=None, max_len=256, dk_injector=None):
    """Serialize a pair of data entries into Ditto's tab-separated format.

    Args:
        ent1 (Dictionary): the 1st data entry
        ent2 (Dictionary): the 2nd data entry
        summarizer (Summarizer, optional): the summarization module
        max_len (int, optional): the max sequence length
        dk_injector (DKInjector, optional): the domain-knowledge injector

    Returns:
        string: the serialized version ("<left>\t<right>\t0", with a dummy
        label of 0 appended)
    """
    def serialize(ent):
        # A string entry is assumed to be pre-serialized already.
        if isinstance(ent, str):
            return ent
        return ''.join('COL %s VAL %s ' % (attr, ent[attr]) for attr in ent.keys())

    content = serialize(ent1) + '\t' + serialize(ent2) + '\t' + '0'

    if summarizer is not None:
        content = summarizer.transform(content, max_len=max_len)

    new_ent1, new_ent2, _ = content.split('\t')
    if dk_injector is not None:
        new_ent1 = dk_injector.transform(new_ent1)
        new_ent2 = dk_injector.transform(new_ent2)

    return '\t'.join([new_ent1, new_ent2, '0'])


def classify(sentence_pairs, model,
             lm='distilbert',
             max_len=256,
             threshold=None):
    """Apply a trained Ditto matcher to a list of serialized pairs.

    Args:
        sentence_pairs (list of str): serialized "<left>\t<right>\t<label>" pairs
        model (DittoModel): the matcher in pytorch
        lm (str, optional): the language model used for tokenization
        max_len (int, optional): the max sequence length
        threshold (float, optional): predict a match when P(class 1) exceeds
            this value (defaults to 0.5)

    Returns:
        list of int: the 0/1 match predictions
        list of list of float: the raw logits of each pair
    """
    dataset = DittoDataset(sentence_pairs,
                           max_len=max_len,
                           lm=lm)
    # A single batch covering the whole input: callers already chunk the
    # candidate pairs upstream (see predict()).
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=len(dataset),
                               shuffle=False,
                               num_workers=0,
                               collate_fn=DittoDataset.pad)

    # prediction
    all_probs = []
    all_logits = []
    with torch.no_grad():
        for batch in iterator:
            # batch is (token_ids, labels); move both to the model's device
            # and keep only the token ids
            x, _ = (tensor.to(model.device) for tensor in batch)
            logits = model(x)
            probs = logits.softmax(dim=1)[:, 1]  # P(match)
            all_probs += probs.cpu().numpy().tolist()
            all_logits += logits.cpu().numpy().tolist()

    if threshold is None:
        threshold = 0.5

    pred = [1 if p > threshold else 0 for p in all_probs]
    return pred, all_logits

def predict(input_path, output_path, config,
            model,
            batch_size=1024,
            summarizer=None,
            lm='distilbert',
            max_len=256,
            dk_injector=None,
            threshold=None):
    """Run the model over the input file containing the candidate entry pairs

    Args:
        input_path (str): the input file path (jsonlines of [left, right]
            pairs; a tab-separated .txt file is converted on the fly)
        output_path (str): the output jsonlines file path
        config (Dictionary): task configuration
        model (DittoModel): the model for prediction
        batch_size (int): the batch size
        summarizer (Summarizer, optional): the summarization module
        max_len (int, optional): the max sequence length
        dk_injector (DKInjector, optional): the domain-knowledge injector
        threshold (float, optional): the threshold of the 0's class

    Returns:
        None
    """
    pairs = []

    def process_batch(rows, pairs, writer):
        # classify one batch of serialized pairs and write one record per pair
        predictions, logits = classify(pairs, model, lm=lm,
                                       max_len=max_len,
                                       threshold=threshold)
        scores = softmax(logits, axis=1)
        for row, pred, score in zip(rows, predictions, scores):
            # NOTE(review): the trailing comma after the closing brace wraps
            # the dict in a 1-tuple, so every record is written to the jsonl
            # file as [ {...} ]. All downstream readers in this file access
            # records via line[0], so removing this comma would require
            # updating every reader as well — left as-is.
            output = {
                'left': row[0],
                'right': row[1],
                'match': pred,
                'match_confidence': score[int(pred)],
                # Shannon entropy (bits) of the class distribution
                'samepEn': sum(np.log2(score) * score * (-1))
                 },
            writer.write(output)

    # input_path can also be train/valid/test.txt
    # convert to jsonlines ([left, right] pairs; the label column is dropped)
    if '.txt' in input_path:
        with jsonlines.open(input_path + '.jsonl', mode='w') as writer:
            for line in open(input_path,'r',encoding='utf-8'):
                writer.write(line.split('\t')[:2])
        input_path += '.jsonl'

    # batch processing
    start_time = time.time()
    with jsonlines.open(input_path) as reader,\
         jsonlines.open(output_path, mode='w') as writer:
        pairs = []
        rows = []
        for idx, row in tqdm(enumerate(reader)):
            pairs.append(to_str(row[0], row[1], summarizer, max_len, dk_injector))
            rows.append(row)
            if len(pairs) == batch_size:
                process_batch(rows, pairs, writer)
                pairs.clear()
                rows.clear()

        # flush the final partial batch
        if len(pairs) > 0:
            process_batch(rows, pairs, writer)

    # append the wall-clock run time to log.txt for bookkeeping
    run_time = time.time() - start_time
    run_tag = '%s_lm=%s_dk=%s_su=%s' % (config['name'], lm, str(dk_injector != None), str(summarizer != None))
    os.system('echo %s %f >> log.txt' % (run_tag, run_time))


def tune_threshold(config, model, hp):
    """Tune the prediction threshold for a given model on a validation set.

    Args:
        config (Dictionary): the task configuration (provides 'validset')
        model (DittoModel): the trained matcher
        hp (Namespace): hyper-parameters (lm, max_len, summarize, dk)

    Returns:
        float: the threshold that maximizes F1 on the validation set
    """
    validset = config['validset']

    # summarize the sequences up to the max sequence length
    set_seed(123)
    summarizer = injector = None
    if hp.summarize:
        summarizer = Summarizer(config, lm=hp.lm)
        validset = summarizer.transform_file(validset, max_len=hp.max_len, overwrite=True)

    if hp.dk is not None:
        if hp.dk == 'product':
            injector = ProductDKInjector(config, hp.dk)
        else:
            injector = GeneralDKInjector(config, hp.dk)

        validset = injector.transform_file(validset)

    # load the dev set; fail loudly instead of swallowing the error (the old
    # bare `except` continued and crashed later with a NameError on
    # valid_dataset)
    try:
        valid_dataset = DittoDataset(validset,
                                     max_len=hp.max_len,
                                     lm=hp.lm)
    except Exception:
        print("构建验证集出错了")
        raise

    valid_iter = data.DataLoader(dataset=valid_dataset,
                                 batch_size=64,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=DittoDataset.pad)

    # search for the threshold that maximizes F1 on the validation set
    f1, th = evaluate(model, valid_iter, threshold=None)

    # verify F1 by re-running a full prediction pass with the tuned threshold
    set_seed(123)
    predict(validset, "../temp/tmp.jsonl", config, model,
            summarizer=summarizer,
            max_len=hp.max_len,
            lm=hp.lm,
            dk_injector=injector,
            threshold=th)

    predicts = []
    with jsonlines.open("../temp/tmp.jsonl", mode="r") as reader:
        for line in reader:
            # predict() writes each record as a 1-tuple, hence line[0]
            predicts.append(int(line[0]['match']))
    os.remove('../temp/tmp.jsonl')

    labels = []
    with open(validset,'r',encoding='utf-8') as fin:
        for line in fin:
            labels.append(int(line.split('\t')[-1]))

    real_f1 = sklearn.metrics.f1_score(labels, predicts)
    print("load_f1 =", f1)
    print("real_f1 =", real_f1)

    return th



def load_model(task, path, lm, use_gpu, fp16=True):
    """Load a trained DittoModel checkpoint for a specific task.

    Args:
        task (str): the task name (a key into ../configs.json)
        path (str): the path of the checkpoint directory
        lm (str): the language model
        use_gpu (boolean): whether to use gpu
        fp16 (boolean, optional): whether to use fp16 (currently unused; kept
            for interface compatibility)

    Returns:
        Dictionary: the task config
        DittoModel: the model with the checkpoint weights loaded

    Raises:
        ModelNotFoundError: if the checkpoint file does not exist
    """
    # locate the checkpoint
    checkpoint = os.path.join(path, task, 'model.pt')
    if not os.path.exists(checkpoint):
        raise ModelNotFoundError(checkpoint)

    # close the config file handle instead of leaking it
    with open('../configs.json') as fp:
        configs = json.load(fp)
    configs = {conf['name']: conf for conf in configs}
    config = configs[task]

    if use_gpu:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'

    model = DittoModel(device=device, lm=lm)

    # load on CPU first, then move to the target device
    saved_state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state['model'])
    model = model.to(device)

    return config, model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='Structured/Beer')
    parser.add_argument("--output_path", type=str, default='../output/matched_small.jsonl')
    parser.add_argument("--lm", type=str, default='distilbert')
    parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
    parser.add_argument("--fp16", dest="fp16", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, default='../checkpoints/')
    parser.add_argument("--dk", type=str, default=None)
    parser.add_argument("--summarize", dest="summarize", action="store_true")
    parser.add_argument("--max_len", type=int, default=256)
    # extra arguments needed because a training stage was added
    parser.add_argument("--size", type=int, default=None)
    parser.add_argument("--da", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=3e-5)
    # Target Domain
    parser.add_argument("--target_task", type=str, default='Structured/Beer')
    # percent controls the dataset size
    parser.add_argument("--percent", type=int, default=None)
    # al decides whether to use active learning
    parser.add_argument("--al", dest="al", action="store_true")
    # IM-loss related flags
    # NOTE(review): argparse's type=bool does not parse "False" as falsy (any
    # non-empty string is truthy); action="store_true" would be safer — verify
    parser.add_argument('--gent', type=bool, default=True)# L_div
    parser.add_argument('--ent', type=bool, default=True)# L_ent

    hp = parser.parse_args()

    # set seeds
    set_seed(123)

    summarizer = dk_injector = None


    # load the source-domain model from the checkpoint
    config, save_model = load_model(hp.task, hp.checkpoint_path,
                       hp.lm, hp.use_gpu, hp.fp16)

    # fetch the target-domain dataset configuration
    target_task = hp.target_task#"wdc_cameras_small"
    configstemp = json.load(open('../configs.json'))
    configstemp = {conf['name']: conf for conf in configstemp}
    # build the data config
    dataconfig = configstemp[target_task]
    # pick labeled data from the train set to fine-tune the source model and
    # record the F1 score on the test set
    trainset = '../' + dataconfig['trainset']
    validset = '../' + dataconfig['validset']
    testset = '../' + dataconfig['testset']


    if hp.use_gpu:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'

    # ======= auxiliary networks =======
    class Feature_Extractor(nn.Module):
        """Wraps the source model's BERT encoder; returns the [CLS] vector."""
        def __init__(self,device='cuda'):
            super(Feature_Extractor, self).__init__()
            self.model = save_model.bert  # shares weights with the source model
            self.device = device

        # feature extraction (encode_forward)
        def forward(self, x1):
            x1 = x1.to(self.device)  # (batch_size, seq_len)
            enc = self.model(x1)[0][:, 0, :]  # take the [CLS] vector
            return enc
    class Classifier(nn.Module):
        """Wraps the source model's final fc layer; outputs class probabilities."""
        def __init__(self,device='cuda'):
            super(Classifier, self).__init__()
            self.model = save_model.fc  # shares weights with the source model
            self.device = device
        # classifier head (fc_forward)
        def forward(self, enc):
            prob = self.model(enc)
            res = prob.softmax(dim=1)
            return res

    # ======= network architecture =======
    # TargetNet = netF + netC
    class TargetNet(nn.Module):
        """Target-domain network: feature extractor followed by classifier."""
        def __init__(self,device='cuda',alpha_aug = 0.8):
            super(TargetNet, self).__init__()
            self.device = device
            self.alpha_aug = alpha_aug  # mixup/augmentation coefficient used by SFDA training
            self.netF = Feature_Extractor(self.device)
            self.netC = Classifier(self.device)

        # full forward pass: feature_extractor + classifier
        def forward(self, x1):
            x1 = x1.to(self.device)  # (batch_size, seq_len)
            enc = self.netF(x1)# [CLS] vector
            return self.netC(enc)  # .squeeze() # .sigmoid()

    # ======= initialize the target network =======
    # NOTE(review): "tragetNet" is a typo for "targetNet"; kept as-is because
    # the name is used throughout and passed to shot_train below
    tragetNet = TargetNet()
    netF = (tragetNet.netF).cuda()
    netF.model.load_state_dict(save_model.bert.state_dict())

    netC = (tragetNet.netC).cuda()
    netC.model.load_state_dict(save_model.fc.state_dict())

    tragetNet.eval()# set to evaluation mode

    # ======= Before SFDA in ER =======
    # assume a small labeled validation split is available to tune the threshold
    threshold = None
    # NOTE(review): this hard-coded 0.5 disables the tune_threshold call below;
    # the original comment marks it as a debugging shortcut to be removed later
    threshold = 0.5# comment this out for the real run; only here to ease debugging
    if threshold is None:
        threshold = tune_threshold(config, save_model, hp)
    # run predict on the trainset; results are saved in matched_small_fix.jsonl
    # used to obtain entropy values of target training samples && pseudo-labels
    hp.output_path = '../output/matched_small_fix.jsonl'
    if not os.access(hp.output_path, os.F_OK):  # skip if the output file already exists
        # run prediction and collect per-sample entropies on the training set
        # by applying the source model to the trainset
        predict(trainset, hp.output_path, config, save_model,
                summarizer=summarizer,
                max_len=hp.max_len,
                lm=hp.lm,
                dk_injector=dk_injector,
                threshold=threshold)
    SamepEn = []
    with jsonlines.open(hp.output_path, mode="r") as reader:
        for line in reader:
            # predict() writes each record as a 1-tuple, hence line[0]
            SamepEn.append(float('%.4f' % (line[0]['samepEn'])))

    # apply the source model to the testset to evaluate model quality
    hp.output_path = '../output/matched_small_fix_test.jsonl'
    if not os.access(hp.output_path, os.F_OK):  # skip if the output file already exists
        predict(testset, hp.output_path, config, save_model,
                summarizer=summarizer,
                max_len=hp.max_len,
                lm=hp.lm,
                dk_injector=dk_injector,
                threshold=threshold)

    print("Befor SFDA:在测试集计算预测准确率")
    print("计算F1")
    Lastpredicts = []
    with jsonlines.open(hp.output_path, mode="r") as reader:
        for line in reader:
            Lastpredicts.append(int(line[0]['match']))
    Lastlabels = []
    with open(testset, 'r', encoding='utf-8') as fin:# originally hp.input_path, changed to the dataset file
        for line in fin:
            Lastlabels.append(int(line.split('\t')[-1]))

    # note: arguments are (predictions, labels); binary F1 is symmetric under
    # this swap, so the score is unaffected
    predicts_f1 = sklearn.metrics.f1_score(Lastpredicts, Lastlabels)
    print("predicts_f1 in "+ str(testset) + " is \n", predicts_f1)
    rightnum = 0
    ind = 0
    while ind < len(Lastlabels):
        if Lastlabels[ind] == Lastpredicts[ind]:
            rightnum = rightnum + 1
        ind = ind + 1
    print("predicts_acc in "+ str(testset) + " is \n", float(('%.4f') % (rightnum / len(Lastlabels))))

    # ======= Run SFDA in ER =======
    # load train/dev/test sets
    train_dataset = DittoDataset(trainset, lm=hp.lm,
                                 max_len=hp.max_len,
                                 size=hp.size,
                                 da=hp.da,
                                 en=None,
                                 percent=None)
    valid_dataset = DittoDataset(validset, lm=hp.lm)
    test_dataset = DittoDataset(testset, lm=hp.lm)

    # Method 2: SHOT (source-hypothesis transfer)
    shot_train(train_dataset,valid_dataset,test_dataset,hp,threshold, tragetNet, netF, netC)